# -*- coding: utf-8 -*-
"""
Created on Sun Jul  3 10:36:41 2022

@author: asus
"""

import numpy as np
from scipy import sparse
import pandas as pd
import networkx as nx
import matplotlib.pylab as mpl
# mpl.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.ticker
import seaborn as sns
from matplotlib.cm import ScalarMappable
import random
import os
import gc
from tqdm import tqdm
import shutil
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.nn.functional as F
from data_utils import PrepareConcDataset,gnn2cnn
from shap_utils import get_topk_array

# Global matplotlib style applied to every figure produced by this module.
config = {
"font.family": 'serif', 
"font.size": 16,
"font.serif": ['SimSun'], # SimSun is a CJK typeface; presumably chosen so the Chinese plot labels below render -- confirm it is installed
"mathtext.fontset": 'stix', 
'axes.unicode_minus': False # draw minus signs with an ASCII hyphen instead of the Unicode minus glyph
}
plt.rcParams.update(config)
# Draw axis ticks pointing into the plotting area.
plt.rcParams['xtick.direction'] = 'in'  # in; out; inout
plt.rcParams['ytick.direction'] = 'in'



#%%padding function

def get_input_time(last_input,inverse_day):
    '''
    Return the time stamps of the `inverse_day` steps following `last_input`.

    last_input shape is B_T_C_H_W; element [.., 0, 0, 0] of each step is read
    as that step's time value.

    The spacing is taken from the first two steps of the first sample and the
    stamps continue from the last step of the last sample.

    Returns a 1-D array with exactly `inverse_day` entries.
    '''
    time_gap = last_input[0, 1, 0, 0, 0] - last_input[0, 0, 0, 0, 0]
    last_time = last_input[-1, -1, 0, 0, 0]
    # Scale integer step counts instead of calling np.arange with a float
    # step: with a non-integer gap, rounding in arange's length computation
    # can yield one stamp too many (or too few).
    steps = np.arange(1, inverse_day + 1)
    return last_time + time_gap * steps


def get_input_padding(last_input,inverse_day):
    '''
    Build the shifted input windows for every step 1..inverse_day.

    last_input shape is B_T_C_H_W

    For each shift the window is rolled left along the time axis, the
    wrapped-around trailing steps are overwritten with the step that sits
    just before them, and their time entries are replaced with the
    extrapolated time stamps. The windows of all shifts are concatenated
    along the batch axis.
    '''
    future_times = get_input_time(last_input, inverse_day)

    shifted_windows = []
    for shift in range(1, inverse_day + 1):
        # np.roll returns a new array, so last_input itself is not modified.
        window = np.roll(last_input, -shift, axis=1)
        anchor_step = window[:, -1 - shift, :, :, :]
        window = get_loop_padding(window, shift, anchor_step)
        window = get_time_padding(window, shift, future_times)
        shifted_windows.append(window)

    return np.concatenate(shifted_windows, axis=0)


def get_time_padding(array,length,time_list):
    '''
    Overwrite channel 0 of the last `length` time steps with time stamps.

    array shape is B_T_C_H_W

    Step -length receives time_list[0], ..., step -1 receives
    time_list[length-1]. `array` is modified in place and returned.
    '''
    # Walk from the last step backwards, like the index arithmetic implies.
    for slot in reversed(range(length)):
        array[:, slot - length, 0, :, :] = time_list[slot]

    return array


def get_loop_padding(array1,length,array2):
    '''
    Fill the last `length` steps along axis 1 of `array1` with `array2`.

    `array1` is indexed with five indices; it is modified in place and
    returned.
    '''
    for step in range(-1, -length - 1, -1):
        array1[:, step, :, :, :] = array2
    return array1


def pad_ensemble2input(ensemble,temp_input,tot_sense_well,):
    '''
    Write one ensemble member's values into `temp_input` in place.

    ensemble is indexed as [day, well]
    temp_input is indexed with four indices: [sample, time, channel, node]
    tot_sense_well is indexed as [day, well, 2]; entry 0 is used as the node
    index and entry 1 (+1) as the channel index

    For day `i`, the value is written to every sample j >= i at time
    position i-j-1 (counted from the end of the time axis).
    Returns the same `temp_input` object.
    '''
    num_well = tot_sense_well.shape[1]
    for day in range(ensemble.shape[0]):
        for sample in range(day, temp_input.shape[0]):
            step = day - sample - 1
            for well in range(num_well):
                channel = tot_sense_well[day, well, 1] + 1
                node = tot_sense_well[day, well, 0]
                temp_input[sample, step, channel, node] = ensemble[day, well]

    return temp_input


def get_nonzero_matrix(array,padding):
    '''
    Drop every entry equal to `padding` from each row of a 2-D array.

    The filtered rows are handed to np.array, so the result is a regular
    2-D array only when every row keeps the same number of entries.
    '''
    num_row, num_col = array.shape[0], array.shape[1]
    kept_rows = [
        [array[row, col] for col in range(num_col) if not array[row, col] == padding]
        for row in range(num_row)
    ]
    return np.array(kept_rows)


def get_sense_well(peak_eg_values,edge_list,topk,):
    '''
    Collect the index pairs of the non-zero entries that get_topk_array
    leaves in the last step of every sample.

    peak_eg_values is indexed as [sample, step, :, :]; only step -1 is used.
    edge_list is accepted but not used.
    NOTE(review): get_topk_array presumably keeps the top-k entries and
    zeroes the rest -- confirm in shap_utils.

    Returns an integer array [sample, entry, index]; within each sample the
    entries are ordered by their second index (np.argsort default order).
    All samples must yield the same number of entries for the stack to work.
    '''
    per_sample = []
    for sample in range(peak_eg_values.shape[0]):
        selected = get_topk_array(peak_eg_values[sample, -1, :, :], topk)
        coords = np.stack(np.where(selected != 0), axis=1)
        # order the entries of this sample by their second index
        per_sample.append(coords[np.argsort(coords[:, 1])])

    return np.stack(per_sample, axis=0)

#%%

def cnn2gnn5(array,edge_list):
    '''
    Gather one grid cell per edge_list row from a gridded array.

    array shape is B_T_C_H_W
    edge_list shape is Edge_3; columns 1 and 2 are cast to int32 and used as
    the (H, W) position of each row

    Returns an array of shape B_T_C_Edge.
    '''
    grid_pos = np.asarray(edge_list[:, 1:], dtype=np.int32)
    per_node = [array[:, :, :, pos[0], pos[1]] for pos in grid_pos]
    return np.stack(per_node, axis=-1)
        

def get_l1loss(output,target):
    '''
    Return the mean absolute difference between `output` and `target`.
    '''
    return abs(output - target).mean()


def detect_bound(ensemble,inverse_flux_well,ies_args):
    '''
    Clamp every ensemble parameter to the bounds configured in `ies_args`.

    ensemble shape is E_T*N; it is reshaped to [member, day, parameter].
    The first 5 parameters of each day are clamped according to the matching
    entry of `inverse_flux_well`: entries >= 20 use the inject bounds,
    entries < 20 use the pump bounds. All remaining parameters use the H
    bounds.

    The reshaped array is written to in place, so when the reshape is a
    view the caller's `ensemble` is modified as well.
    Returns the clamped ensemble with shape [member, day*parameter].
    '''
    num_member = ies_args['num_ensemble']
    num_day = ies_args['inverse_day']

    members = ensemble.reshape(num_member, num_day, -1)
    well_index = np.tile(inverse_flux_well, (num_member, 1, 1))
    is_inject = well_index >= 20
    is_pump = well_index < 20

    # `flux` is a view, so it reflects each write-back below.
    flux = members[:, :, :5]
    inject_clamped = np.minimum(np.maximum(flux, ies_args['inject_lower_bound']),
                                ies_args['inject_upper_bound'])
    members[:, :, :5] = np.where(is_inject, inject_clamped, flux)

    pump_clamped = np.minimum(np.maximum(flux, ies_args['pump_lower_bound']),
                              ies_args['pump_upper_bound'])
    members[:, :, :5] = np.where(is_pump, pump_clamped, flux)

    remaining = members[:, :, 5:]
    members[:, :, 5:] = np.minimum(np.maximum(remaining, ies_args['H_lower_bound']),
                                   ies_args['H_upper_bound'])

    return members.reshape(num_member, num_day * members.shape[-1])
    

def configure_ies(ies_args,last_input,temp_target,tot_sense_well,inverse_flux_well):
    '''
    Prepare the observation vector and draw the initial ensemble.

    last_input shape is B_T_C_N
    temp_target shape is B_N

    The first 5 parameters of every day are drawn from a normal prior whose
    mean is channel 1 of the last step of the first sample at the wells in
    `inverse_flux_well`, and whose std is the std of channel 1 over nodes
    20.. (well index >= 20) or nodes ..20 (well index < 20), scaled by
    ies_args['std_diffuse']. The remaining parameters are standard normal.
    The ensemble is then clamped with detect_bound.

    Returns obser, measurement (flattened obser), ensemble [member,
    day*parameter], principal_sqrtR (identity of size output_channel),
    input_channel (C-1) and output_channel (last axis of obser).
    '''
    obser = temp_target
    measurement = obser.flatten()  # B*Fout

    # unpacking requires last_input to be 4-D
    _, _, num_channel, _ = last_input.shape
    input_channel = num_channel - 1
    output_channel = obser.shape[-1]

    num_member = ies_args['num_ensemble']
    num_day = ies_args['inverse_day']
    num_param = tot_sense_well.shape[1]

    prior_flux_mean = np.stack([last_input[0, -1, 1, well_row] for well_row in inverse_flux_well],
                               axis=0)

    # One std per well group; each is only evaluated when the group occurs.
    prior_flux_std = np.zeros_like(prior_flux_mean)
    inject_mask = inverse_flux_well >= 20
    if inject_mask.any():
        prior_flux_std[inject_mask] = np.std(last_input[0, :, 1, 20:]) * ies_args['std_diffuse']
    if (~inject_mask).any():
        prior_flux_std[~inject_mask] = np.std(last_input[0, :, 1, :20]) * ies_args['std_diffuse']

    # Draw order matters for reproducibility: randn first, then normal.
    ensemble = np.random.randn(num_member, num_day, num_param)  # E_B*(Fin-1)
    ensemble[:, :, :5] = np.random.normal(prior_flux_mean, prior_flux_std, ensemble[:, :, :5].shape)
    ensemble = ensemble.reshape(num_member, num_day * num_param)

    ensemble = detect_bound(ensemble, inverse_flux_well, ies_args)
    principal_sqrtR = np.eye(output_channel)

    return obser, measurement, ensemble, principal_sqrtR, input_channel, output_channel


def ensemble2input(ensemble,inverse_day,tot_sense_well,temp_input,edge_list,count,scaler_Q,scaler_H):
    '''
    Turn every ensemble member into a model input.

    Each row of `ensemble` is reshaped to [inverse_day, parameter], written
    into `temp_input` (in place, via pad_ensemble2input) and converted with
    gnn2cnn. The fill values handed to gnn2cnn are the raw value 0 expressed
    in each scaler's standardized units, computed from mean_ and var_.

    Returns a list with one converted input per ensemble row.
    '''
    num_param = tot_sense_well.shape[1]
    padded_input = []
    for member in range(ensemble.shape[0]):
        member_values = ensemble[member, :].reshape(inverse_day, num_param)
        filled_input = pad_ensemble2input(member_values, temp_input, tot_sense_well)
        zero_Q = (0 - scaler_Q.mean_) / np.sqrt(scaler_Q.var_)
        zero_H = (0 - scaler_H.mean_) / np.sqrt(scaler_H.var_)
        grid_input = gnn2cnn(filled_input, edge_list, count, ['same', zero_Q, zero_H])
        padded_input.append(grid_input)

    return padded_input



#%%

def get_output_loss(ensemble,measurement,temp_input,tot_sense_well,edge_list,count,scaler_Q,scaler_H,
                    inverse_day,input_channel,model,wbase,ne):
    '''
    Run the surrogate model on every ensemble member and score the result.

    ensemble shape is E_B*F*N
    measurement shape is B*F
    temp_input is the input template the ensemble values are written into

    The model output is reshaped to [member, inverse_day] and divided by
    `wbase`. The objective is the mean absolute error between the first
    `ne` rows and `measurement`. `input_channel` is accepted but not used.

    Returns (sim_data, obj).
    '''
    padded_input = ensemble2input(ensemble, inverse_day, tot_sense_well, temp_input,
                                  edge_list, count, scaler_Q, scaler_H)
    padded_input = np.concatenate(padded_input, axis=0)
    ensemble_tensor = torch.FloatTensor(padded_input)

    # Pure inference: disable autograd so no graph is kept for the whole
    # ensemble batch.
    with torch.no_grad():
        prediction = model(ensemble_tensor)

    sim_data = prediction.detach().cpu().numpy().squeeze(-1).reshape(ensemble.shape[0], inverse_day)
    sim_data = sim_data / np.tile(wbase, (sim_data.shape[0], 1))

    obj = get_l1loss(sim_data[:ne, :], measurement)
    print('optimize objection=',obj)

    return sim_data, obj


#%%

def inner_iteration(ies_args,nd,ne,nf,input_channel,ensemble,measurement,sim_data,perturbed_data,wbase,model,
                    tot_sense_well,inverse_flux_well,edge_list,count,scaler_Q,scaler_H,inverse_day,temp_input,
                    ud,wd,vd,svdpd,deltaM,deltaD,obj,lambd,iterat):
    '''
    Try ensemble updates with increasing damping until the objective drops.

    ensemble / sim_data hold `ne` member rows followed by the mean row.
    deltaM is [ne, nf] and deltaD is [ne, nd] (member deviations from the
    mean row). With ies_args['do_tsvd'] the gain is built from the truncated
    SVD factors ud [nd, svdpd], wd [svdpd], vd [ne, svdpd]; otherwise it is
    obtained from the regularised normal equations
    deltaM.T @ deltaD @ inv(deltaD.T @ deltaD + alpha*I).

    A trial update that raises the objective is discarded and lambd is
    multiplied by 'lambd_incre'; the first update that does not raise it is
    accepted, lambd is multiplied by 'lambd_reduct', `iterat` is incremented
    and the loop stops. At most max_in_iter-1 trials are made.

    Returns iter_lambd, lambd, iterat, is_min_rn (1 when the relative
    objective change in percent is below 'min_rn'), ensemble, sim_data, obj.
    '''
    iter_lambd = 1  # inner iteration counter
    is_min_rn = 0
    max_inn_iter = ies_args['max_in_iter']
    lambd_reduct = ies_args['lambd_reduct']
    lambd_incre = ies_args['lambd_incre']
    do_tsvd = ies_args['do_tsvd']
    min_rn = ies_args['min_rn']

    while iter_lambd < max_inn_iter:

        print('*'*15,'inner interation step:',iter_lambd,'*'*15)

        ensemble_old = ensemble.copy()
        sim_data_old = sim_data.copy()

        if do_tsvd:
            alpha = lambd * np.sum(wd**2) / svdpd
            # Scaling the columns of vd equals vd @ diag(wd/(wd^2+alpha)).
            x1 = vd * (wd / (wd**2 + alpha))
            kgain = deltaM.T @ x1 @ ud.T
        else:
            alpha = lambd * np.sum(deltaD**2) / nd
            # The Gram matrix is symmetric, so solving for the transpose and
            # transposing back gives deltaM.T @ deltaD @ inv(gram) without
            # forming an explicit inverse.
            gram = deltaD.T @ deltaD + alpha * np.eye(nd)
            kgain = np.linalg.solve(gram, deltaD.T @ deltaM).T

        iterated_ensemble = ensemble[:ne, :] - (sim_data[:ne, :] - perturbed_data) @ kgain.T
        iterated_ensemble = detect_bound(iterated_ensemble, inverse_flux_well, ies_args)
        ensemble_mean = iterated_ensemble.mean(axis=0)
        ensemble = np.vstack([iterated_ensemble, ensemble_mean])

        m_change = np.sqrt(np.sum((ensemble[:ne, :] - ensemble_old[:ne, :])**2) / nf)
        print('average change (in RMSE) of ensemble mean=',m_change)

        sim_data, obj_new = get_output_loss(ensemble, measurement, temp_input, tot_sense_well,
                                            edge_list, count, scaler_Q, scaler_H,
                                            inverse_day, input_channel, model, wbase, ne)

        if obj_new > obj:
            # Reject the trial: restore the previous state and damp harder.
            lambd = lambd * lambd_incre
            print('lambd increase to',lambd)
            iter_lambd = iter_lambd + 1
            sim_data = sim_data_old
            ensemble = ensemble_old
        else:
            lambd = lambd * lambd_reduct
            print('lambd reduce to',lambd)

            iterat = iterat + 1

            if abs(obj_new - obj) / abs(obj) * 100 < min_rn:
                is_min_rn = 1

            obj = obj_new
            break

    return iter_lambd, lambd, iterat, is_min_rn, ensemble, sim_data, obj
    
    
def outter_iteration(ies_args,nd,ne,nf,input_channel,init_obj,ensemble,measurement,sim_data,perturbed_data,wbase,model,
                     tot_sense_well,inverse_flux_well,edge_list,count,scaler_Q,scaler_H,inverse_day,temp_input):
    '''
    Outer loop of the iterative ensemble smoother.

    ensemble / sim_data hold `ne` member rows followed by the mean row.
    Each pass forms the member deviations from the mean row, optionally
    takes a truncated SVD of the data deviations (keeping the smallest
    number of leading singular values, at most ne-1, whose share of the
    singular-value sum exceeds ies_args['tsvd_cut']) and runs
    inner_iteration.

    The loop ends when `iterat` reaches 'max_out_iter', the objective is at
    most beta**2*nd, the relative objective change is below 'min_rn', or
    lambd exceeds 'max_lambd'.

    Returns (ensemble, objs, lambds), the last two being the per-pass
    objective and lambd histories as lists.
    '''
    iterat = 0
    obj = init_obj
    init_lambd = ies_args['init_lambd']
    lambd = init_lambd
    beta = ies_args['beta']
    obj_thresh = beta**2 * nd
    max_out_iter = ies_args['max_out_iter']
    max_inn_iter = ies_args['max_in_iter']
    lambd_incre = ies_args['lambd_incre']
    min_rn = ies_args['min_rn']
    max_lambd = ies_args['max_lambd']

    do_tsvd = ies_args['do_tsvd']
    tsvd_cut = ies_args['tsvd_cut']
    # flags of iES termination status; 1st => maxOuterIter; 2nd => objThreshold; 3rd => min_RN_change; 4th => max_lambd
    exit_flag = [0, 0, 0, 0]
    objs = []
    lambds = []

    while iterat < max_out_iter and obj > obj_thresh:

        print('-'*25,'outer iteration step:',iterat,'-'*25)
        print('number of measurement elements is ',measurement.size)

        deltaM = ensemble[:ne, :] - np.ones((ne, 1)) @ ensemble[ne:, :]
        deltaD = sim_data[:ne, :] - np.ones((ne, 1)) @ sim_data[ne:, :]

        # Always bound, so the call below also works when TSVD is disabled.
        ud = wd = vd = svdpd = None
        if do_tsvd:
            ud, val, vd = np.linalg.svd(deltaD.T, full_matrices=False)
            vd = vd.T
            total = np.sum(val)
            svdpd = 1  # keeps the rank defined when ne == 1
            for j in range(1, ne):
                svdpd = j
                if val[:j].sum() / total > tsvd_cut:
                    break
            # Cannot retain more singular values than the SVD produced.
            svdpd = min(svdpd, len(val))

            print('svdpd=',svdpd)

            ud = ud[:, :svdpd]
            wd = val[:svdpd]
            vd = vd[:, :svdpd]

        iter_lambd, lambd, iterat, is_min_rn, ensemble, sim_data, obj = inner_iteration(
            ies_args, nd, ne, nf, input_channel, ensemble, measurement, sim_data, perturbed_data, wbase, model,
            tot_sense_well, inverse_flux_well, edge_list, count, scaler_Q, scaler_H, inverse_day, temp_input,
            ud, wd, vd, svdpd, deltaM, deltaD, obj, lambd, iterat)
        objs.append(obj)
        lambds.append(lambd)

        if iter_lambd >= max_inn_iter:
            # No trial was accepted: damp harder (never below the initial
            # lambd) and count the pass as an outer iteration.
            lambd = lambd * lambd_incre
            if lambd < init_lambd:
                lambd = init_lambd

            iterat = iterat + 1
            print('terminating inner iterations: iterLambda >= maxInnerIter')

        if is_min_rn:
            print('terminating outer iterations: reduction of objective function is less than ',min_rn)
            exit_flag[2] = 1
            break

        if lambd > max_lambd:
            print('terminating outer iterations: lambd is bigger than ',max_lambd)
            exit_flag[3] = 1
            break

    if iterat >= max_out_iter:
        print('terminating outer iterations: iter >= maxOuterIter')
        exit_flag[0] = 1

    if obj <= obj_thresh:
        print('terminating outer iterations: obj <= objThreshold')
        exit_flag[1] = 1

    print('exit_flag=',exit_flag)

    return ensemble, objs, lambds
        
        
#%%

def ies_main(ies_args,model,edge_list,count,mmin_list,dev_list,scaler_Q,scaler_H,
             inverse_flux_well,tot_sense_well,last_input,temp_input,temp_target):
    '''
    Run the complete ensemble-smoother inversion.

    Steps: draw and clamp the initial ensemble (configure_ies), append the
    ensemble-mean row, scale the measurement by `wbase`, create one
    uniformly perturbed copy of the measurement per member (relative
    amplitude ies_args['noise']), evaluate the initial objective, run
    outter_iteration and finally evaluate the resulting ensemble.
    `mmin_list` and `dev_list` are accepted but not used.

    Returns ensemble (members + mean row), ensemble_output, ensemble_input
    (list of model inputs), objs and lambds (arrays).
    '''
    (obser, measurement, ensemble, principal_sqrtR,
     input_channel, output_channel) = configure_ies(ies_args, last_input, temp_target,
                                                    tot_sense_well, inverse_flux_well)
    inverse_day = ies_args['inverse_day']

    nd = len(measurement)
    ne = ensemble.shape[0]

    # The mean of the members is carried along as an extra last row.
    member_mean = ensemble.mean(axis=0)[np.newaxis, :]
    ensemble = np.vstack([ensemble, member_mean])

    # B*Fout: one copy of the diagonal of principal_sqrtR per observation row
    wbase = np.array([np.diag(principal_sqrtR) for _ in range(obser.shape[0])]).flatten()
    measurement = measurement / wbase

    perturbed_data = np.zeros((ne, nd))
    weight = ies_args['noise'] * measurement
    for member in range(ne):
        perturbed_data[member, :] = measurement + weight * np.random.uniform(-1, 1, measurement.shape)

    nf = ensemble.shape[1]
    sim_data, init_obj = get_output_loss(ensemble, measurement, temp_input, tot_sense_well,
                                         edge_list, count, scaler_Q, scaler_H,
                                         inverse_day, input_channel, model, wbase, ne)
    ensemble, objs, lambds = outter_iteration(ies_args, nd, ne, nf, input_channel, init_obj, ensemble,
                                              measurement, sim_data, perturbed_data, wbase, model,
                                              tot_sense_well, inverse_flux_well, edge_list, count,
                                              scaler_Q, scaler_H, inverse_day, temp_input)

    ensemble_input = ensemble2input(ensemble, inverse_day, tot_sense_well, temp_input,
                                    edge_list, count, scaler_Q, scaler_H)
    ensemble_output, _ = get_output_loss(ensemble, measurement, temp_input, tot_sense_well,
                                         edge_list, count, scaler_Q, scaler_H,
                                         inverse_day, input_channel, model, wbase, ne)
    return ensemble, ensemble_output, ensemble_input, np.array(objs), np.array(lambds)


#%%

def plot_ensemble_output(ensemble_output,temp_target,base_target,real_uranium,loot_figure_path,log_dir,criterion1,i):
    '''
    Plot every ensemble member's output together with the ensemble mean,
    the target and the measured series, and save the figure.

    ensemble_output shape is ne_T (last row is the ensemble mean)
    temp_target shape is T_1
    `i` selects the panel letter: 0 -> (a), 1 -> (b), ...
    `base_target` is accepted but not plotted.

    The figure is written to
    loot_figure_path + log_dir + '/inverse_error_deviation{criterion1}.png'.
    '''
    ne = ensemble_output.shape[0] - 1
    output_mae = abs(ensemble_output[-1, :] - temp_target[:, 0]).mean()
    blue_colors = mpl.cm.Blues(np.linspace(0, 1, ne))
    fig_name = f'({chr(i+97)})'
    fig, axs = plt.subplots(1, 1, figsize=(5, 5), dpi=200, sharey=True)

    # Each member is drawn once; only the last one carries the legend
    # label so the legend shows a single ensemble entry.
    for member in range(ne):
        line_kwargs = {'label': '集成值'} if member == ne - 1 else {}
        axs.plot(ensemble_output[member, :], c=blue_colors[member], **line_kwargs)

    axs.plot(ensemble_output[-1, :], c='yellow', label='集成平均值')
    axs.plot(temp_target[:, 0], c='green', label='目标值')
    axs.plot(real_uranium, c='C1', label='实测值')
    axs.set_ylim(40, 60)
    axs.set_title(f'{fig_name}目标值={criterion1} (kg/d)')
    axs.set_xlabel('时间(d)')
    axs.set_ylabel('总铀(kg/d)')
    # error annotation placed to the right of the axes
    axs.text(1.03, 0.45, f'误差={output_mae:.2f} (kg/d)',
             transform=axs.transAxes)

    axs.legend(loc='upper left', bbox_to_anchor=(1.01, 1.02), framealpha=0.9)

    path_t = loot_figure_path + log_dir + f'/inverse_error_deviation{criterion1}.png'
    plt.savefig(path_t, bbox_inches='tight', dpi=200)
    plt.clf()
    plt.close()


def plot_convergence(objs,lambds,loot_figure_path,log_dir,criterion1):
    '''
    Plot the objective history (left) and the lambda history (right) and
    save the figure to
    loot_figure_path + log_dir + '/objective_and_lambda{criterion1}.png'.
    '''
    fig, (obj_ax, lambd_ax) = plt.subplots(1, 2, figsize=(10, 5), dpi=200, sharex=True)

    obj_ax.plot(objs)
    obj_ax.set_title('optimize objective curve')

    lambd_ax.plot(lambds)
    lambd_ax.set_title('lambda curve')

    target_file = loot_figure_path + log_dir + f'/objective_and_lambda{criterion1}.png'
    plt.savefig(target_file, bbox_inches='tight', dpi=200)
    plt.clf()
    plt.close()


def plot_pj_ratio(ies_args,tot_pj_ratio,real_pj_ratio,loot_figure_path,log_dir,criterion1):
    '''
    Plot the ratio series before and inside the inversion window.

    tot_pj_ratio is a list whose items are indexed [row, time]; the last
    row of each item is plotted. The last `inverse_day` time steps form the
    inversion window.
    real_pj_ratio is indexed [row, time]; row 0 is plotted over the window.
    criterion1 supplies one legend value per item of tot_pj_ratio.

    The figure is saved to
    loot_figure_path + log_dir + '/pj_ratio_all_criterion.png'.
    '''
    inverse_day = ies_args['inverse_day']
    fig, axs = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

    # history: everything before the window, taken from the first item
    history = tot_pj_ratio[0][-1, :-inverse_day]
    axs.plot(range(history.shape[0]), history,
             label='过去值')

    # measured values inside the window
    window_start = real_pj_ratio[:, :-inverse_day].shape[1]
    axs.plot(range(window_start, real_pj_ratio.shape[1]), real_pj_ratio[0, -inverse_day:],
             label='实测值')

    # one marker series per item
    for idx in range(len(tot_pj_ratio)):
        ratio = tot_pj_ratio[idx]
        first_step = ratio[-1, :-inverse_day].shape[0]
        axs.plot(range(first_step, ratio.shape[1]), ratio[-1, -inverse_day:], 'o',
                 label=f'目标值={criterion1[idx]}')

    axs.set_xlabel('时间 (d)')
    axs.set_ylabel('采注比')
    axs.legend(loc='upper left', bbox_to_anchor=(1.01, 1.02), framealpha=0.9)

    target_file = loot_figure_path + log_dir + '/pj_ratio_all_criterion.png'
    plt.savefig(target_file, bbox_inches='tight', dpi=200)
    plt.clf()
    plt.close()



#%%
def plot_sig_ensemble(fig,ax,fig_names,feature_name,unit,i,j,
                      ies_args,node_coord,tot_sense_well,inversed_val):
    '''
    Draw one map panel: all wells in white, and the selected wells of day
    `i` coloured by the inverted value of feature `j`.

    node_coord is indexed [node, (x, y)]; nodes 0..19 are drawn as circles,
    the remaining nodes as stars.
    tot_sense_well is indexed [day, well, feature] and holds node indices.
    inversed_val[j] is indexed [day, well]; its global min/max set the
    colour range.

    A colorbar is attached to `ax` unless there are no selected wells.
    '''
    ax.set_title(fig_names[j*ies_args['inverse_day']+i]+f'第{i+1}天{feature_name[j]}{unit[j]}')

    # Background: every well in white. No cmap here -- matplotlib ignores
    # it (with a warning) when `c` is a fixed colour.
    ax.scatter(node_coord[:20, 0], node_coord[:20, 1],
               c='white', marker='o', s=100, edgecolors='black',
               linewidths=0.1)
    ax.scatter(node_coord[20:, 0], node_coord[20:, 1],
               c='white', marker='*', s=150, edgecolors='black',
               linewidths=0.1)

    num_well = tot_sense_well.shape[1]
    if num_well == 0:
        # Nothing to colour, hence nothing a colorbar could describe.
        return

    values = inversed_val[j]
    vmin, vmax = values.min(), values.max()

    last_scatter = None
    for k in range(num_well):  # loop over every selected well
        node = tot_sense_well[i, k, j]
        marker, size = ('o', 100) if node < 20 else ('*', 150)
        last_scatter = ax.scatter(node_coord[node, 0], node_coord[node, 1],
                                  c=values[i, k], marker=marker, s=size,
                                  edgecolors='black', linewidths=1, cmap='bwr',
                                  vmin=vmin, vmax=vmax)

    # All scatters share cmap/vmin/vmax, so any of them defines the bar.
    cbar = fig.colorbar(last_scatter, ax=ax)
    cbar.formatter.set_powerlimits((0, 0))
    cbar.update_ticks()
    
    
def plot_ensemble(inversed_flux,inversed_H,tot_sense_well,node_coord,
                  ies_args,criterion1,path):
    '''
    Draw the inverted flux and acid maps of every inversion day and save
    the figure to path + '/ensemble{criterion1}'.

    The grid has 4 columns: each row shows two days (flux and acid panel
    per day). When the number of days is odd, the last day is centred in
    an extra row whose outer panels are switched off.

    tot_sense_well is indexed [day, well, 2]; entry 0 of the first 5 wells
    gives the flux nodes and entry 0 of the remaining wells the acid nodes.
    node_coord columns 1.. are used as coordinates.

    Raises ValueError when ies_args['inverse_day'] is smaller than 1.
    '''
    inverse_day = ies_args['inverse_day']
    if inverse_day < 1:
        raise ValueError('inverse_day must be at least 1')

    # Panel letters: flux panels get (a),(c),(e)..., acid panels (b),(d),...
    fig_names = [f'({chr(2*day+97)})' for day in range(inverse_day)]
    fig_names.extend([f'({chr(2*day-1+97)})' for day in range(1, inverse_day+1)])
    feature_name = ['流量', '酸量']
    unit = ['(m$^{3}$/d)', '(g/L)']
    node_coord = np.array(node_coord[:, 1:], dtype=np.float32)
    inversed_val = [inversed_flux, inversed_H]
    tot_sense_well = np.stack([tot_sense_well[:, :5, 0], tot_sense_well[:, 5:, 0]], axis=-1)
    num_feature = tot_sense_well.shape[-1]

    full_rows = inverse_day // 2
    has_tail_row = inverse_day % 2 == 1
    num_rows = full_rows + (1 if has_tail_row else 0)

    # squeeze=False keeps `axs` two-dimensional even for a single row.
    fig, axs = plt.subplots(num_rows, 4,
                            figsize=(20, 5*num_rows),
                            dpi=200, sharex=True, sharey=True, squeeze=False)

    for row in range(full_rows):
        for j in range(num_feature):
            plot_sig_ensemble(fig, axs[row, j], fig_names, feature_name, unit, 2*row, j,
                              ies_args, node_coord, tot_sense_well, inversed_val)

        for j in range(num_feature):
            plot_sig_ensemble(fig, axs[row, j+2], fig_names, feature_name, unit, 2*row+1, j,
                              ies_args, node_coord, tot_sense_well, inversed_val)

    if has_tail_row:
        # Row index is computed explicitly instead of relying on the loop
        # variable, which is undefined when there are no full rows.
        last_day = 2 * full_rows
        for j in range(num_feature):
            plot_sig_ensemble(fig, axs[full_rows, j+1], fig_names, feature_name, unit, last_day, j,
                              ies_args, node_coord, tot_sense_well, inversed_val)

        axs[full_rows, 0].axis('off')
        axs[full_rows, -1].axis('off')

    path_t = path + f'/ensemble{criterion1}'
    plt.savefig(path_t, bbox_inches='tight', dpi=200)
    plt.clf()
    plt.close()






